knitr::opts_chunk$set(echo = TRUE)
options(stringsAsFactors = FALSE)

library(tidyverse)
theme_set(theme_classic())

library(cowplot)
library(gplots)
library(ggrepel)

# Dataset label and the file name used for the exported expression table.
type <- "BirchWood"
expr_file <- "BirchWood_transcriptomics.txt.gz"

Preprocess data

# Column layout of the count table: the gene identifier column followed by
# 28 zero-padded sample numbers for each of the prefixes B4, B5 and B1.
samples <- c(
  "Genes",
  sprintf("%s-%02d", rep(c("B4", "B5", "B1"), each = 28), rep(1:28, times = 3))
)

# Counts
# Three B1 columns are set to zero before the columns are put in the order
# given by `samples` (presumably these samples are absent from the file -
# TODO confirm).
counts <- read_tsv("../DATA/Counts/BirchWood_salmon.merged.gene_counts.tsv") %>%
  mutate(`B1-13` = 0, `B1-14` = 0, `B1-16` = 0)
counts <- counts[, samples]

# Rename samples
# B4 -> B1, B5 -> B2, B1 -> B3. The new names use "." instead of "-", so the
# freshly created "B1." names are not touched by the later "B1-" replacement.
samples <- colnames(counts) %>%
  gsub("B4-", "B1.", .) %>%
  gsub("B5-", "B2.", .) %>%
  gsub("B1-", "B3.", .)
colnames(counts) <- samples

# Mapped reads
# Library size per sample (column sums in millions) before smoothing.
lib.size <- colSums(counts[, -1]) / 1E6
depth <- data.frame(Samples = factor(names(lib.size), levels = names(lib.size)),
                    MappedReads = lib.size,
                    Type = "Raw")

# Smooth counts
# Every sample column is replaced by the mean of itself and its immediate
# neighbours that belong to the same tree (the part of the sample name before
# the "."). The means are computed from an untouched copy of the table: when
# the table is overwritten in place, the left neighbour of each column has
# already been smoothed while the right neighbour is still raw, which makes
# the result asymmetric and dependent on column order.
raw.counts <- counts
tree.of <- vapply(strsplit(samples, "\\."), `[`, character(1), 1)
last.col <- ncol(counts)
for (i in 2:last.col) {
  # Column 1 holds the gene identifiers, so column 2 never has a left neighbour.
  has.left <- i > 2 && tree.of[i - 1] == tree.of[i]
  has.right <- i < last.col && tree.of[i + 1] == tree.of[i]
  window <- (i - has.left):(i + has.right)
  counts[[i]] <- unname(rowMeans(raw.counts[, window, drop = FALSE]))
}
rm(raw.counts, tree.of, last.col, has.left, has.right, window)


# Library size per sample after smoothing, appended for a side-by-side plot.
lib.size <- colSums(counts[, -1]) / 1E6
depth <- rbind(depth, data.frame(Samples = factor(names(lib.size), levels = names(lib.size)),
                                 MappedReads = lib.size,
                                 Type = "Smooth"))

# Horizontal bar chart of mapped reads per sample, one panel per `Type`.
p1 <- ggplot(depth, aes(x = Samples, y = MappedReads, fill = Samples)) +
  geom_col() +
  geom_text(aes(label = format(MappedReads, digits = 2)), hjust = -0.1, size = 2) +
  labs(y = "Million mapped reads") +
  coord_flip() +
  facet_grid(cols = vars(Type)) +
  theme(legend.position = "none", text = element_text(size = 7))

# Remove outliers
# B3.01 and B3.10 through B3.26 are dropped from the sample list and from the
# count table.
remove_samples <- c("B3.01", sprintf("B3.%02d", 10:26))
samples <- samples[is.na(match(samples, remove_samples))]
cat("Outlier samples: ", remove_samples, "(", length(remove_samples), ")\n")
## Outlier samples:  B3.01 B3.10 B3.11 B3.12 B3.13 B3.14 B3.15 B3.16 B3.17 B3.18 B3.19 B3.20 B3.21 B3.22 B3.23 B3.24 B3.25 B3.26 ( 18 )
counts <- counts[, samples]

# VST
library(DESeq2)
# Gene identifiers become row names so the table can be used as a matrix.
counts <- counts %>% column_to_rownames(var = "Genes") %>% as.matrix()
# The counts are rounded to whole numbers before the variance stabilizing
# transformation is applied.
expr.wide <- varianceStabilizingTransformation(round(counts))
# Shift the transformed values so that the smallest one is zero.
expr.wide <- expr.wide - min(expr.wide)
expr.wide <- expr.wide %>% as.data.frame() %>% rownames_to_column(var = "Genes")

# Drop intermediates that are no longer needed. Only objects that actually
# exist are listed; naming missing ones makes rm() emit warnings.
rm(counts, lib.size, i)

# Filter
# Long format: one row per gene and sample, with the sample name split at the
# "." into the tree label and the numeric sample index.
expr.long <- expr.wide %>%
  gather(Sample.names, Expression, -1) %>%
  separate(Sample.names, into = c("Trees", "Samples"), sep = "\\.", remove = FALSE) %>%
  mutate(Samples = as.numeric(Samples))

# Keep genes that reach an expression of at least 3 in two or more samples of
# a tree, in two or more trees.
expressed_genes <- expr.long %>%
  filter(Expression >= 3) %>%
  dplyr::count(Genes, Trees, name = "n.samples") %>%
  filter(n.samples >= 2) %>% # Two samples
  dplyr::count(Genes, name = "n.trees") %>%
  filter(n.trees >= 2) %>% # Two trees
  pull(Genes)

cat(length(expressed_genes), "expressed genes with two samples with log2(TPM) > 3 in two trees\n")
## 15433 expressed genes with two samples with log2(TPM) > 3 in two trees
expr.long <- expr.long %>%
  filter(Genes %in% expressed_genes) %>%
  mutate(Sample.names = factor(Sample.names, levels = samples))

# Back to wide format (genes as row names, one column per sample) for the
# clustering and PCA steps.
expr.wide <- expr.long %>%
  select(Genes, Sample.names, Expression) %>%
  spread(Sample.names, Expression) %>%
  as.data.frame() %>%
  column_to_rownames(var = "Genes")

# Expressed genes per sample
# Number of genes per sample with an expression of at least 3, drawn as a
# horizontal bar chart next to the mapped-reads plot.
p2 <- expr.long %>%
  filter(Expression >= 3) %>%
  dplyr::count(Sample.names, name = "NoGenes") %>%
  mutate(Type = "Smooth") %>%
  ggplot(aes(x = Sample.names, y = NoGenes, fill = Sample.names)) +
  geom_col() +
  geom_text(aes(label = format(NoGenes, digits = 2)), hjust = -0.1, size = 2) +
  labs(x = "Samples", y = "Expressed genes (VST > 3)") +
  coord_flip() +
  facet_grid(cols = vars(Type)) +
  theme(legend.position = "none", text = element_text(size = 7))

plot_grid(p1, p2, ncol = 2, rel_widths = c(2, 1))

# Histogram
# Distribution of all expression values; the bin count is derived from the
# largest observed value.
ggplot(expr.long, aes(x = Expression)) +
  geom_histogram(bins = 2 * floor(max(expr.long$Expression)) + 1)

# Print expression table
# Tab-separated file with the gene identifiers restored as the first column.
write_delim(
  rownames_to_column(as.data.frame(expr.wide), var = "Genes"),
  expr_file,
  delim = "\t"
)

Hierarchical clustering

n <- ncol(expr.wide)

# Cluster samples
# Ward clustering on 1 - Pearson correlation between sample columns.
#cdist <- dist(t(expr.wide), method="euclidean")
cdist <- as.dist(1 - cor(expr.wide, method = "pearson"))
chc <- hclust(cdist, method = "ward.D2")
cdendro <- as.dendrogram(chc)

# Leaf weights for reordering the dendrogram: columns are ranked by the sample
# number after the "." in the column name, ties (the same number in several
# trees) are broken by column position.
sample.number <- as.integer(
  vapply(strsplit(colnames(expr.wide), "[.]"), `[`, character(1), 2)
)
leaf.weights <- rank(sample.number, ties.method = "first")
cdendro <- reorder(cdendro, leaf.weights, agglo.FUN = mean)
chc <- as.hclust(cdendro)
#plot(cdendro)

cnclust <- 4
cct <- cutree(chc, k = cnclust)
colsamples <- c("#A3CC51", "#51A3CC", "#CC5151", "#B26F2C")
ccol <- colsamples[cct]

# Column separators for the heatmap: the positions, in dendrogram display
# order, after which the sample cluster changes. Cluster ids from cutree() are
# not guaranteed to increase from left to right along the dendrogram, so the
# breaks are read from the leaf order instead of from cumulative cluster sizes.
cbreaks <- unname(which(diff(cct[chc$order]) != 0))

# Cluster genes
# Ward clustering on 1 - Pearson correlation between gene rows.
rdist <- as.dist(1 - cor(t(expr.wide), method = "pearson"))
rhc <- hclust(rdist, method = "ward.D2")
rdendro <- as.dendrogram(rhc)

rnclust <- 8
rct <- cutree(rhc, k = rnclust)
#colgenes <- c("lightsteelblue3", "#A3CC51", "#51A3CC", "#92d1f0", "orange", "#B26F2C", "mediumpurple3", "#CC5151")
# NOTE(review): "#CC5151" is listed twice, so two clusters share a colour -
# confirm this is intended.
colgenes <- c("#B26F2C", "orange", "#A3CC51", "lightsteelblue3", "#CC5151" , "#CC5151", "#51A3CC", "mediumpurple3")

# Renumber the clusters 1, 2, ... in the order they are met when walking the
# dendrogram leaves from the last to the first, so that cluster ids follow
# the leaf order.
leaf.order <- rev(rhc$order)
leaf.cluster <- unname(rct[leaf.order])
starts.run <- c(TRUE, leaf.cluster[-1] != leaf.cluster[-length(leaf.cluster)])
rct[leaf.order] <- as.numeric(cumsum(starts.run))

rcol <- colgenes[rct]

colorR <- colorRampPalette(c("blue", "white", "red"))(20)

# Draw the row-scaled expression heatmap with the gene dendrogram on the rows.
# The two views below differ only in how the columns are ordered, where the
# column separators are placed and which dendrograms are shown.
draw_expression_heatmap <- function(Colv, colsep, dendrogram) {
  heatmap.2(as.matrix(expr.wide),
            Rowv = rdendro,
            Colv = Colv,
            colsep = colsep,
            sepcolor = "black",
            dendrogram = dendrogram,
            scale = "row",
            margins = c(5, 5),
            trace = "none",
            cexRow = 0.2,
            cexCol = 0.5,
            labRow = rep("", nrow(expr.wide)),
            labCol = colnames(expr.wide),
            main = NULL,
            xlab = NULL, ylab = NULL,
            col = colorR,
            key = FALSE, keysize = 1, density.info = "none",
            ColSideColors = ccol,
            RowSideColors = rcol)
}

# Columns ordered by the sample dendrogram, separated between sample clusters.
h <- draw_expression_heatmap(Colv = cdendro, colsep = cbreaks, dendrogram = "both")

# Columns in their original order, separated between trees. The separators
# are the cumulative number of columns per tree, leaving out the last tree;
# cumsum() also copes with one or two trees, where an index loop starting at
# 2 would run backwards.
tree.sizes <- table(gsub("\\.\\d\\d", "", colnames(expr.wide)))
cbreaks2 <- cumsum(head(tree.sizes, -1))

h2 <- draw_expression_heatmap(Colv = FALSE, colsep = cbreaks2, dendrogram = "row")

# Filter clusters
# One centroid per gene cluster: the mean expression of its member genes in
# every sample. Columns are named V1, V2, ... in cluster order.
centroids <- as.data.frame(
  vapply(seq_len(rnclust),
         function(k) colMeans(expr.wide[rct == k, ]),
         numeric(ncol(expr.wide)))
)

# Correlation of every gene with every centroid; a gene whose correlation
# with the centroid of its own cluster is below 0.75 is moved to cluster 0.
R <- cor(t(expr.wide), centroids)
own.centroid.cor <- R[cbind(seq_len(nrow(R)), rct)]
rct.filtered <- rct
rct.filtered[own.centroid.cor < 0.75] <- 0


# Plot cluster profiles

# Per-gene cluster assignment (0 = removed by the centroid filter) and colour.
gene.clusters <- data.frame(Genes = names(rct.filtered), Clusters = rct.filtered, Colors = rcol)

expr.long <- left_join(expr.long, gene.clusters, by = "Genes")

# Z-score every gene across all samples. The grouping is dropped afterwards so
# later steps do not silently operate per gene.
expr.long <- expr.long %>%
  group_by(Genes) %>%
  mutate(Expression.scaled = (Expression - mean(Expression)) / sd(Expression)) %>%
  ungroup()

tree.select <- "B1"
tree.cct <- cct[grepl(tree.select, names(cct))]

# Vertical guide lines for the profile plots: halfway between two consecutive
# samples of the selected tree whenever their sample cluster differs. Reading
# the changes from the samples themselves avoids assuming that cluster ids
# increase along the tree or that sample numbers are contiguous from 1.
tree.samples <- as.numeric(sub("^.*\\.", "", names(tree.cct)))
by.sample <- order(tree.samples)
changes <- which(diff(unname(tree.cct[by.sample])) != 0)
tree.breaks <- (tree.samples[by.sample][changes] +
                  tree.samples[by.sample][changes + 1]) / 2

d <- tibble(Trees = tree.select,
            xintercept = tree.breaks)
save(d, gene.clusters, file = "gene.cluster.RData")

# One panel per gene cluster for the selected tree: every member gene as a
# thin line, the cluster mean as a thick line, on a background in the cluster
# colour. The y axis is limited to the range of the cluster mean.
plots <- lapply(seq_len(rnclust), function(k) {

  members <- expr.long %>% filter(Clusters == k, Trees == tree.select)

  profile <- members %>%
    group_by(Samples) %>%
    summarize(Expression.scaled = mean(Expression.scaled))

  ggplot(members, aes(x = Samples, y = Expression.scaled, col = Genes)) +
    geom_line() +
    ylim(range(profile$Expression.scaled)) +
    scale_color_manual(values = rep("azure3", length(unique(members$Genes)))) +
    geom_vline(data = d, mapping = aes(xintercept = xintercept), linetype="dashed", col = "azure2", size=0.5) +
    theme_void() +
    theme(legend.position = "none",
          panel.background = element_rect(fill = colgenes[k], colour = colgenes[k])) +
    geom_line(data = profile, mapping = aes(x = Samples, y = Expression.scaled), col = "azure2", size = 2)
})

plot_grid(plotlist = plots, ncol = 3)

PCA

# Principal components of the samples (samples as rows, genes as variables).
pc <- prcomp(t(expr.wide))

# Fraction of the total variance carried by each component.
var.expl <- pc$sdev^2 / sum(pc$sdev^2)
#paste0("Varance explained: ", paste0(format(var.expl[1:5], digits = 2), collapse = " "))

# Sample scores together with the sample cluster and the sample name.
p <- cbind(pc$x, data.frame(Clusters = as.factor(cct),
                            Samples = rownames(pc$x)))

# PC1 against PC2, coloured by sample cluster.
p1 <- ggplot(p, aes(PC1, PC2, col = Clusters)) +
  geom_point() +
  scale_color_manual(values = colsamples) +
  labs(x = paste0("PC1 (", round(var.expl[1], digits = 2), ")"),
       y = paste0("PC2 (", round(var.expl[2], digits = 2), ")")) +
  theme(legend.position = "none")

# The same plot with every point labelled by its sample name.
p2 <- p1 +
  geom_label_repel(aes(label = Samples), box.padding=0.01, point.padding=0.01, segment.color='grey50', label.size=0.1, size=2.5)

plot_grid(p1, p2)

Marker genes

Select one gene (the first listed) from each cluster with max expression > 8 and min expression < 1.

# For the selected tree, take the first gene of every retained cluster
# (cluster id above 0) whose expression peaks above 8 and drops below 1.
marker_genes <- expr.long %>%
  filter(Trees == tree.select, Clusters > 0) %>%
  group_by(Genes) %>%
  filter(max(Expression) > 8, min(Expression) < 1) %>%
  group_by(Clusters) %>%
  dplyr::slice(1) %>%
  pull(Genes)

# Expression of the marker genes in the selected tree.
marker.expr <- expr.long %>%
  filter(Trees == tree.select, Genes %in% marker_genes)

# Cluster colour of every marker gene, ordered by gene identifier.
marker.colors <- marker.expr %>%
  distinct(Genes, .keep_all = TRUE) %>%
  arrange(Genes) %>%
  pull(Colors)

color <- factor(marker.colors, levels = unique(marker.colors))

# One panel per marker gene, drawn in the colour of its cluster.
ggplot(marker.expr, aes(x = Samples, y = Expression, col = Genes)) +
  geom_line(size = 2) +
  facet_grid(rows = vars(Genes)) +
  theme(legend.position = "none") +
  scale_color_manual(values = as.vector(color))

Orthologs of the marker genes in the AspWood paper

marker_genes_names <- c("SUS6", "CDC2", "EXPA1", "CesA8", "BFN1")
#marker_genes_potri_ids <- c("Potri.004G081300", "Potri.016G142800", "Potri.001G240900", "Potri.004G059600", "Potri.011G044500")
marker_genes_arab_ids <- c("AT1G73370", "AT2G38620", "AT2G28950", "AT4G18780", "AT1G11190")

# Positions of the vertical guide lines for the selected tree.
d <- tibble(Trees = tree.select,
            xintercept = tree.breaks)

# Orthogroups
# Long table with one row per orthogroup and gene. The Arabidopsis and birch
# gene lists of each orthogroup are joined, split at ", ", and every gene is
# labelled by species (identifiers starting with "AT" count as Arabidopsis).
# All steps are vectorised, so no row-wise evaluation is needed.
orthogroups <- read_tsv("../DATA/OrthoGroups/Orthogroups_20240823_clean.tsv.gz", show_col_types = FALSE) %>%
  select(OrthoGroup, Arabidopsis_thaliana, Betula_pendula) %>%
  mutate(Genes = paste(Arabidopsis_thaliana, Betula_pendula, sep = ", ")) %>%
  select(OrthoGroup, Genes) %>%
  separate_rows(Genes, sep = ", ") %>%
  filter(Genes != "") %>%
  mutate(Species = if_else(grepl("^AT", Genes), "Arab", "Birch")) %>%
  distinct()

# Build the two expression plots for one marker: the expression values as
# stored in expr.long, and the same values raised as a power of two.
#   ids   - gene identifiers to draw, one line per gene
#   title - plot title
# Returns a list with the two ggplot objects, in that order.
# NOTE(review): the second panel is labelled "CPM" although it shows
# 2^Expression of the shifted VST values - confirm the intended unit.
plot_marker_pair <- function(ids, title) {
  marker.expr <- expr.long %>%
    filter(Genes %in% ids, Trees == tree.select)

  draw <- function(data, y.label) {
    ggplot(data, aes(x = Samples, y = Expression, col = Genes)) +
      geom_line(size = 2) +
      ggtitle(title) + ylab(y.label) +
      geom_vline(data = d, mapping = aes(xintercept = xintercept), linetype="dashed", color = "gray68", size=0.5) +
      theme(legend.position = "none")
  }

  list(
    draw(marker.expr, "Expression (VST)"),
    draw(marker.expr %>% mutate(Expression = 2^Expression), "Expression (CPM)")
  )
}

# All birch genes that share an orthogroup with each Arabidopsis marker.
plots <- list()
for (k in seq_along(marker_genes_names)) {

  orth <- orthogroups %>%
    filter(Genes == marker_genes_arab_ids[k]) %>%
    pull(OrthoGroup)

  # %in% keeps this valid when the marker is found in no orthogroup (or in
  # several); the panel is then empty instead of the lookup failing.
  ids <- orthogroups %>%
    filter(OrthoGroup %in% orth, Species == "Birch") %>%
    pull(Genes)

  plots <- c(plots, plot_marker_pair(ids, marker_genes_names[k]))
}

plot_grid(plotlist = plots, ncol = 2)

# One hand-picked birch gene per marker.
marker_genes_selected <- list()
marker_genes_selected[["SUS6"]] <- c("Bpev01.c0727.g0009.m0001")
marker_genes_selected[["CDC2"]] <- c("Bpev01.c0480.g0058.m0001")
marker_genes_selected[["EXPA1"]] <- c("Bpev01.c0564.g0008.m0001")
marker_genes_selected[["CesA8"]] <- c("Bpev01.c0000.g0006.m0001")
marker_genes_selected[["BFN1"]] <- c("Bpev01.c0690.g0007.m0001")

plots <- list()
for (name in names(marker_genes_selected)) {
  plots <- c(plots, plot_marker_pair(marker_genes_selected[[name]], name))
}

plot_grid(plotlist = plots, ncol = 2)